import numpy as np
import copy
import os
import random
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from params import get_args
from env.env import JSP_Env
from model.REINFORCE import REINFORCE
from heuristic import *
import json
import time

MAX = float(1e6)

def train():
    """Train the module-level REINFORCE ``policy`` for ``args.episode`` episodes.

    Reads the globals created in the ``__main__`` block: ``args``, ``env``,
    ``policy``, ``optimizer`` and ``scheduler``.

    Each episode:
      * saves ``policy.state_dict()`` to ``./weight/<args.date>/<episode>``
        every 1000 episodes (before that episode trains);
      * resets ``env`` until it returns a non-None set of available ops;
      * rolls out one full schedule with ``policy``, recording per-step
        log-probabilities, entropies, negated rewards and an MWKR-heuristic
        baseline (MWKR makespan from the current state minus the current
        makespan);
      * takes one ``optimizer`` step on ``policy.calculate_loss(...)`` and one
        ``scheduler`` step;
      * prints the policy makespan alongside the MWKR makespan.

    After the last episode the final weights are saved to
    ``./weight/<args.date>/<args.episode>``.
    """
    print("start Training")
    for episode in range(0, args.episode):
        baselines = []
        rewards = []
        log_probs = []
        entropies = []

        if episode % 1000 == 0:
            torch.save(policy.state_dict(), "./weight/{}/{}".format(args.date, str(episode)))

        # env.reset() can return None (presumably an instance it could not
        # set up -- confirm in JSP_Env); keep resetting until we get ops.
        avai_op = env.reset()
        while avai_op is None:
            avai_op = env.reset()

        # Makespan of the pure MWKR heuristic from the initial state. Deep
        # copies keep the heuristic rollout from mutating the live env.
        mwkr_ms = heuristic_makespan(copy.deepcopy(env), copy.deepcopy(avai_op), "MWKR")

        done = False
        while not done:
            mwkr_baseline = heuristic_makespan(copy.deepcopy(env), copy.deepcopy(avai_op), "MWKR")
            baseline = mwkr_baseline - env.get_makespan()
            data = env.get_graph_data()
            action_idx, _action_prob, log_prob, entropy = policy(avai_op, data)
            avai_op, reward, done = env.step(avai_op[action_idx])

            baselines.append(baseline)
            log_probs.append(log_prob)
            entropies.append(entropy)
            # Rewards are stored negated; calculate_loss consumes them as-is.
            rewards.append(-reward)

        optimizer.zero_grad()
        loss, _policy_loss, _entropy_loss = policy.calculate_loss(args.device, log_probs, entropies, baselines, rewards)
        loss.backward()
        optimizer.step()
        scheduler.step()

        print("Episode : {}".format(episode))
        ms = env.get_makespan()

        print("Date : {}\t\t Job : {} \t\tMachine : {} \t\tPolicy : {} \t\tImprove: {} \t\tMWKR : {}".format(
            args.date, env.jsp_instance.job_num, env.jsp_instance.machine_num,
            ms, mwkr_ms - ms, mwkr_ms
        ))

    # The periodic checkpoint above fires *before* an episode trains, so
    # without this the weights from the final episodes would never be saved.
    torch.save(policy.state_dict(), "./weight/{}/{}".format(args.date, str(args.episode)))
    
if __name__ == '__main__':
    # NOTE: these names stay module-level on purpose -- train() reads
    # args/env/policy/optimizer/scheduler as globals.
    args = get_args()
    print(args)

    # Per-run output directories ./result/<date> and ./weight/<date>;
    # makedirs creates the parent directory as needed.
    prepare_dirs = ['result', 'weight']
    for _dir in prepare_dirs:
        os.makedirs(os.path.join(_dir, args.date), exist_ok=True)

    # Record the run configuration. Use "w" rather than "a": appending a
    # second json.dump on a re-run with the same date leaves two
    # concatenated JSON objects in the file, which json.load cannot parse.
    with open("./result/{}/args.json".format(args.date), "w", encoding="utf-8") as outfile:
        json.dump(vars(args), outfile, indent=8)

    env = JSP_Env(args)

    policy = REINFORCE(args).to(args.device)
    optimizer = optim.Adam(policy.parameters(), lr=args.lr)
    # Learning rate decays by a factor of 0.99 every args.step_size scheduler steps.
    scheduler = StepLR(optimizer, step_size=args.step_size, gamma=0.99)
    train()